{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2018 The TensorFlow Authors. All Rights Reserved.  \n",
    "  \n",
    " Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    " you may not use this file except in compliance with the License.  \n",
    " You may obtain a copy of the License at  \n",
    "  \n",
    "     http://www.apache.org/licenses/LICENSE-2.0  \n",
    "  \n",
    " Unless required by applicable law or agreed to in writing, software  \n",
    " distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    " See the License for the specific language governing permissions and  \n",
    " limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the dependency .py files, if any.\n",
    "# Skip the clone when the checkout already exists so that this cell can be\n",
    "# re-run (e.g. after a kernel restart) without `git clone` failing.\n",
    "! test -d cloudml-samples || git clone https://github.com/GoogleCloudPlatform/cloudml-samples.git\n",
    "! cp cloudml-samples/census/tensorflowcore/trainer/* .\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"A feed-forward neural network using TensorFlow Core APIs.\n",
    "\n",
    "It implements a binary classifier for the Census Income Dataset that can be\n",
    "trained on either a single node or a distributed cluster.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import json\n",
    "import os\n",
    "import threading\n",
    "import six\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.ops import control_flow_ops\n",
    "from tensorflow.python.ops import lookup_ops\n",
    "from tensorflow.python.ops import variables\n",
    "from tensorflow.python.saved_model import signature_constants as sig_constants\n",
    "import model as model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EvaluationRunHook(tf.train.SessionRunHook):\n",
    "  \"\"\"EvaluationRunHook performs continuous evaluation of the model.\n",
    "\n",
    "  After every training step the hook looks for a new checkpoint in\n",
    "  `checkpoint_dir`. Once more than `eval_frequency` new checkpoints have\n",
    "  been observed since the previous evaluation, the latest checkpoint is\n",
    "  restored into a separate session on `graph`, the metric ops are run and\n",
    "  the merged summaries are written to `<checkpoint_dir>/eval`. A final\n",
    "  evaluation is also attempted when the training session ends.\n",
    "\n",
    "  Args:\n",
    "    checkpoint_dir (string): Dir to store model checkpoints\n",
    "    metric_dict (dict): Metrics like accuracy and auroc, passed to\n",
    "      `tf.contrib.metrics.aggregate_metric_map`\n",
    "    graph (tf.Graph): Evaluation graph\n",
    "    eval_frequency (int): Number of new checkpoints that must be exceeded\n",
    "      before the next evaluation is run\n",
    "    eval_steps (int): Evaluation steps to be performed. If None, evaluation\n",
    "      runs until the coordinator reports that it should stop\n",
    "    **kwargs: Extra keyword arguments; stored but not otherwise used here\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self,\n",
    "               checkpoint_dir,\n",
    "               metric_dict,\n",
    "               graph,\n",
    "               eval_frequency,\n",
    "               eval_steps=None,\n",
    "               **kwargs):\n",
    "\n",
    "    self._eval_steps = eval_steps\n",
    "    self._checkpoint_dir = checkpoint_dir\n",
    "    self._kwargs = kwargs\n",
    "    self._eval_every = eval_frequency\n",
    "    self._latest_checkpoint = None\n",
    "    self._checkpoints_since_eval = 0\n",
    "    self._graph = graph\n",
    "\n",
    "    # With the graph object as default graph.\n",
    "    # See https://www.tensorflow.org/api_docs/python/tf/Graph#as_default\n",
    "    # Adds ops to the graph object\n",
    "    with graph.as_default():\n",
    "      value_dict, update_dict = tf.contrib.metrics.aggregate_metric_map(\n",
    "          metric_dict)\n",
    "\n",
    "      # Op that creates a Summary protocol buffer by merging summaries\n",
    "      self._summary_op = tf.summary.merge([\n",
    "          tf.summary.scalar(name, value_op)\n",
    "          for name, value_op in six.iteritems(value_dict)\n",
    "      ])\n",
    "\n",
    "      # Saver class add ops to save and restore\n",
    "      # variables to and from checkpoint\n",
    "      self._saver = tf.train.Saver()\n",
    "\n",
    "      # Creates a global step to contain a counter for\n",
    "      # the global training step\n",
    "      self._gs = tf.train.get_or_create_global_step()\n",
    "\n",
    "      self._final_ops_dict = value_dict\n",
    "      # Materialize the update ops as a list: on Python 3 `dict.values()`\n",
    "      # is a view object, which `Session.run` does not accept as a fetch.\n",
    "      self._eval_ops = list(update_dict.values())\n",
    "\n",
    "    # MonitoredTrainingSession runs hooks in background threads\n",
    "    # and it doesn't wait for the thread from the last session.run()\n",
    "    # call to terminate to invoke the next hook, hence locks.\n",
    "    self._eval_lock = threading.Lock()\n",
    "    self._checkpoint_lock = threading.Lock()\n",
    "    self._file_writer = tf.summary.FileWriter(\n",
    "        os.path.join(checkpoint_dir, 'eval'), graph=graph)\n",
    "\n",
    "  def after_run(self, run_context, run_values):\n",
    "    \"\"\"Track new checkpoints and evaluate once enough have accumulated.\"\"\"\n",
    "    # Always check for new checkpoints in case a single evaluation\n",
    "    # takes longer than checkpoint frequency and _eval_every is >1\n",
    "    self._update_latest_checkpoint()\n",
    "\n",
    "    # Non-blocking acquire: if an evaluation is already in progress this\n",
    "    # call simply skips evaluating.\n",
    "    if self._eval_lock.acquire(False):\n",
    "      try:\n",
    "        if self._checkpoints_since_eval > self._eval_every:\n",
    "          self._checkpoints_since_eval = 0\n",
    "          self._run_eval()\n",
    "      finally:\n",
    "        self._eval_lock.release()\n",
    "\n",
    "  def _update_latest_checkpoint(self):\n",
    "    \"\"\"Update the latest checkpoint file created in the output dir.\"\"\"\n",
    "    if self._checkpoint_lock.acquire(False):\n",
    "      try:\n",
    "        latest = tf.train.latest_checkpoint(self._checkpoint_dir)\n",
    "        if latest != self._latest_checkpoint:\n",
    "          self._checkpoints_since_eval += 1\n",
    "          self._latest_checkpoint = latest\n",
    "      finally:\n",
    "        self._checkpoint_lock.release()\n",
    "\n",
    "  def end(self, session):\n",
    "    \"\"\"Called at the end of session to make sure we always evaluate.\"\"\"\n",
    "    self._update_latest_checkpoint()\n",
    "\n",
    "    with self._eval_lock:\n",
    "      self._run_eval()\n",
    "\n",
    "  def _run_eval(self):\n",
    "    \"\"\"Run model evaluation and generate summaries.\n",
    "\n",
    "    Restores the latest known checkpoint into a fresh session on the\n",
    "    evaluation graph, runs the metric ops and writes the merged summaries\n",
    "    for the restored global step. Nothing is written when there is no\n",
    "    checkpoint yet or when no evaluation step completed.\n",
    "    \"\"\"\n",
    "    # Without a checkpoint there is nothing to restore.\n",
    "    if self._latest_checkpoint is None:\n",
    "      tf.logging.warning(\n",
    "          'No checkpoint found in {}; skipping evaluation.'.format(\n",
    "              self._checkpoint_dir))\n",
    "      return\n",
    "\n",
    "    coord = tf.train.Coordinator(clean_stop_exception_types=(\n",
    "        tf.errors.CancelledError, tf.errors.OutOfRangeError))\n",
    "\n",
    "    with tf.Session(graph=self._graph) as session:\n",
    "      # Restores previously saved variables from latest checkpoint\n",
    "      self._saver.restore(session, self._latest_checkpoint)\n",
    "\n",
    "      session.run([\n",
    "          tf.tables_initializer(),\n",
    "          tf.local_variables_initializer()])\n",
    "      tf.train.start_queue_runners(coord=coord, sess=session)\n",
    "      train_step = session.run(self._gs)\n",
    "\n",
    "      tf.logging.info('Starting Evaluation For Step: {}'.format(train_step))\n",
    "      # Pre-bind the results so that a loop which never completes a step\n",
    "      # (eval_steps == 0, or an exception handed to the coordinator on the\n",
    "      # first run) cannot leave these names undefined below.\n",
    "      summaries = None\n",
    "      final_values = None\n",
    "      with coord.stop_on_exception():\n",
    "        eval_step = 0\n",
    "        while not coord.should_stop() and (self._eval_steps is None or\n",
    "                                           eval_step < self._eval_steps):\n",
    "          summaries, final_values, _ = session.run(\n",
    "              [self._summary_op, self._final_ops_dict, self._eval_ops])\n",
    "          if eval_step % 100 == 0:\n",
    "            tf.logging.info('On Evaluation Step: {}'.format(eval_step))\n",
    "          eval_step += 1\n",
    "\n",
    "      if summaries is None:\n",
    "        tf.logging.warning(\n",
    "            'No evaluation steps completed for step: {}'.format(train_step))\n",
    "        return\n",
    "\n",
    "      # Write the summaries\n",
    "      self._file_writer.add_summary(summaries, global_step=train_step)\n",
    "      self._file_writer.flush()\n",
    "      tf.logging.info(final_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run(target, cluster_spec, is_chief, args):\n",
    "  \"\"\"Builds the training graph, trains the model and exports it.\n",
    "\n",
    "  The chief additionally builds an evaluation graph that is attached to\n",
    "  the training session through an EvaluationRunHook, and exports the\n",
    "  model from the latest checkpoint once training has finished.\n",
    "\n",
    "  Args:\n",
    "    target (str): Tensorflow server target.\n",
    "    cluster_spec: (cluster spec) Cluster specification.\n",
    "    is_chief (bool): Boolean flag to specify a chief server.\n",
    "    args (args): Input Arguments.\n",
    "  \"\"\"\n",
    "\n",
    "  # Layer i has first_layer_size * scale_factor**i units, floored at 2.\n",
    "  hidden_units = []\n",
    "  for layer_index in range(args.num_layers):\n",
    "    layer_size = int(args.first_layer_size * args.scale_factor**layer_index)\n",
    "    hidden_units.append(max(2, layer_size))\n",
    "\n",
    "  # Only the chief evaluates; every other task trains without extra hooks.\n",
    "  #\n",
    "  # See https://youtu.be/la_M6bCV91M?t=1203 for details on\n",
    "  # distributed TensorFlow and motivation about chief.\n",
    "  hooks = []\n",
    "  if is_chief:\n",
    "    tf.logging.info('Created DNN hidden units {}'.format(hidden_units))\n",
    "    evaluation_graph = tf.Graph()\n",
    "    with evaluation_graph.as_default():\n",
    "      # Evaluation inputs: a single pass over the files unless an explicit\n",
    "      # number of evaluation steps was requested.\n",
    "      eval_features, eval_labels = model.input_fn(\n",
    "          args.eval_files,\n",
    "          num_epochs=None if args.eval_steps else 1,\n",
    "          batch_size=args.eval_batch_size,\n",
    "          shuffle=False)\n",
    "\n",
    "      # In EVAL mode the result of model.model_fn is used as the metric\n",
    "      # dict handed to the evaluation hook.\n",
    "      metric_dict = model.model_fn(\n",
    "          model.EVAL,\n",
    "          eval_features.copy(),\n",
    "          eval_labels,\n",
    "          hidden_units=hidden_units,\n",
    "          learning_rate=args.learning_rate)\n",
    "\n",
    "    hooks.append(\n",
    "        EvaluationRunHook(\n",
    "            args.job_dir,\n",
    "            metric_dict,\n",
    "            evaluation_graph,\n",
    "            args.eval_frequency,\n",
    "            eval_steps=args.eval_steps))\n",
    "\n",
    "  # Training happens in its own graph, made the default for this scope.\n",
    "  with tf.Graph().as_default():\n",
    "    # The replica device setter decides op placement for the given cluster.\n",
    "    #\n",
    "    # See:\n",
    "    # https://www.tensorflow.org/api_docs/python/tf/train/replica_device_setter\n",
    "    device_fn = tf.train.replica_device_setter(cluster=cluster_spec)\n",
    "    with tf.device(device_fn):\n",
    "      train_features, train_labels = model.input_fn(\n",
    "          args.train_files,\n",
    "          num_epochs=args.num_epochs,\n",
    "          batch_size=args.train_batch_size)\n",
    "\n",
    "      # In TRAIN mode model.model_fn yields the train op and the global\n",
    "      # step tensor.\n",
    "      train_op, global_step_tensor = model.model_fn(\n",
    "          model.TRAIN,\n",
    "          train_features.copy(),\n",
    "          train_labels,\n",
    "          hidden_units=hidden_units,\n",
    "          learning_rate=args.learning_rate)\n",
    "\n",
    "    # MonitoredTrainingSession is a Session-like object that handles\n",
    "    # initialization, recovery and hooks\n",
    "    # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession\n",
    "    session_kwargs = dict(\n",
    "        master=target,\n",
    "        is_chief=is_chief,\n",
    "        checkpoint_dir=args.job_dir,\n",
    "        hooks=hooks,\n",
    "        save_checkpoint_secs=20,\n",
    "        save_summaries_steps=50)\n",
    "    with tf.train.MonitoredTrainingSession(**session_kwargs) as session:\n",
    "      # The global step tracks progress across all tasks, which matters in\n",
    "      # the distributed setting.\n",
    "      step = global_step_tensor.eval(session=session)\n",
    "\n",
    "      # Keep training until the optional step limit is hit or the session\n",
    "      # asks to stop (e.g. the training epochs are exhausted).\n",
    "      while True:\n",
    "        step_limit_reached = (args.train_steps is not None and\n",
    "                              step >= args.train_steps)\n",
    "        if step_limit_reached or session.should_stop():\n",
    "          break\n",
    "        step, _ = session.run([global_step_tensor, train_op])\n",
    "\n",
    "    # Find the filename of the latest saved checkpoint file\n",
    "    latest_checkpoint = tf.train.latest_checkpoint(args.job_dir)\n",
    "\n",
    "    # Exporting the model is the chief's responsibility.\n",
    "    if is_chief:\n",
    "      serving_input_fn = model.SERVING_INPUT_FUNCTIONS[args.export_format]\n",
    "      build_and_run_exports(latest_checkpoint,\n",
    "                            args.job_dir,\n",
    "                            serving_input_fn,\n",
    "                            hidden_units)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main_op():\n",
    "  \"\"\"Returns one op grouping local-variable and table initialization.\"\"\"\n",
    "  return control_flow_ops.group(\n",
    "      variables.local_variables_initializer(),\n",
    "      lookup_ops.tables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_and_run_exports(latest, job_dir, serving_input_fn, hidden_units):\n",
    "  \"\"\"Exports a SavedModel built from the given checkpoint.\n",
    "\n",
    "  A prediction graph is built from `serving_input_fn` and the model in\n",
    "  PREDICT mode, the checkpoint is restored into it, and the result is\n",
    "  saved under `<job_dir>/export` with a single default serving signature.\n",
    "\n",
    "  Args:\n",
    "    latest (str): Latest checkpoint file.\n",
    "    job_dir (str): Location of checkpoints and model files.\n",
    "    serving_input_fn (callable): Called without arguments; returns the\n",
    "      features and the dict of input tensors for the serving signature.\n",
    "    hidden_units (list): Number of hidden units.\n",
    "  \"\"\"\n",
    "\n",
    "  def _tensor_infos(named_tensors):\n",
    "    \"\"\"Maps every tensor of a name -> tensor dict to its TensorInfo.\"\"\"\n",
    "    return {\n",
    "        key: tf.saved_model.utils.build_tensor_info(value)\n",
    "        for key, value in six.iteritems(named_tensors)\n",
    "    }\n",
    "\n",
    "  prediction_graph = tf.Graph()\n",
    "  export_dir = os.path.join(job_dir, 'export')\n",
    "  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)\n",
    "\n",
    "  with prediction_graph.as_default():\n",
    "    features, inputs_dict = serving_input_fn()\n",
    "    # Neither labels nor a learning rate are supplied in prediction mode.\n",
    "    prediction_dict = model.model_fn(\n",
    "        model.PREDICT,\n",
    "        features.copy(),\n",
    "        None,\n",
    "        hidden_units=hidden_units,\n",
    "        learning_rate=None)\n",
    "    saver = tf.train.Saver()\n",
    "\n",
    "    signature_def = tf.saved_model.signature_def_utils.build_signature_def(\n",
    "        inputs=_tensor_infos(inputs_dict),\n",
    "        outputs=_tensor_infos(prediction_dict),\n",
    "        method_name=sig_constants.PREDICT_METHOD_NAME)\n",
    "\n",
    "  with tf.Session(graph=prediction_graph) as session:\n",
    "    init_ops = [tf.local_variables_initializer(), tf.tables_initializer()]\n",
    "    session.run(init_ops)\n",
    "    saver.restore(session, latest)\n",
    "\n",
    "    serving_tags = [tf.saved_model.tag_constants.SERVING]\n",
    "    signature_map = {\n",
    "        sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def\n",
    "    }\n",
    "    builder.add_meta_graph_and_variables(\n",
    "        session,\n",
    "        tags=serving_tags,\n",
    "        signature_def_map=signature_map,\n",
    "        legacy_init_op=main_op())\n",
    "\n",
    "  builder.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(args):\n",
    "  \"\"\"Derive the cluster setup from TF_CONFIG and start this task.\n",
    "\n",
    "  When TF_CONFIG is unset or empty, or carries no task type or index,\n",
    "  training runs locally with this process acting as chief. Otherwise a\n",
    "  server is started for the described task: `ps` tasks block in\n",
    "  `server.join()`, while `master` and `worker` tasks call run(), the\n",
    "  master being the chief. Any other task type returns without training.\n",
    "\n",
    "  Args:\n",
    "    args (args): Input arguments.\n",
    "  \"\"\"\n",
    "\n",
    "  def _run_locally():\n",
    "    \"\"\"Trains in-process without a cluster, acting as chief.\"\"\"\n",
    "    return run(target='', cluster_spec=None, is_chief=True, args=args)\n",
    "\n",
    "  raw_config = os.environ.get('TF_CONFIG')\n",
    "  if not raw_config:\n",
    "    return _run_locally()\n",
    "\n",
    "  config = json.loads(raw_config)\n",
    "  task = config.get('task', {})\n",
    "  job_name = task.get('type')\n",
    "  task_index = task.get('index')\n",
    "\n",
    "  # Without a task description there is no cluster role to take on.\n",
    "  if job_name is None or task_index is None:\n",
    "    return _run_locally()\n",
    "\n",
    "  cluster_spec = tf.train.ClusterSpec(config.get('cluster'))\n",
    "  server = tf.train.Server(\n",
    "      cluster_spec, job_name=job_name, task_index=task_index)\n",
    "\n",
    "  # See a detailed video on distributed TensorFlow\n",
    "  # https://www.youtube.com/watch?v=la_M6bCV91M\n",
    "  if job_name == 'ps':\n",
    "    # Parameter servers only wait for incoming connections.\n",
    "    server.join()\n",
    "    return None\n",
    "\n",
    "  if job_name in ('master', 'worker'):\n",
    "    return run(\n",
    "        server.target,\n",
    "        cluster_spec,\n",
    "        is_chief=(job_name == 'master'),\n",
    "        args=args)\n",
    "\n",
    "  return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Command-line style configuration for the training job. Inside a notebook\n",
    "# no flags are normally passed, so the defaults below are what is used;\n",
    "# parse_known_args() ignores any arguments this parser does not define.\n",
    "parser = argparse.ArgumentParser()\n",
    "# Input Arguments\n",
    "# nargs='+' yields a list when the flag is given, so the defaults are\n",
    "# lists as well and the value always has the same type.\n",
    "parser.add_argument(\n",
    "    '--train-files',\n",
    "    nargs='+',\n",
    "    help='Training files local or GCS',\n",
    "    default=['gs://cloud-samples-data/ml-engine/census/data/adult.data.csv'])\n",
    "parser.add_argument(\n",
    "    '--eval-files',\n",
    "    nargs='+',\n",
    "    help='Evaluation files local or GCS',\n",
    "    default=['gs://cloud-samples-data/ml-engine/census/data/adult.test.csv'])\n",
    "parser.add_argument(\n",
    "    '--job-dir',\n",
    "    type=str,\n",
    "    help=\"\"\"GCS or local dir for checkpoints, exports, and summaries.\n",
    "      Use an existing directory to load a trained model, or a new directory\n",
    "      to retrain\"\"\",\n",
    "    default='/tmp/census-tensorflowcore')\n",
    "parser.add_argument(\n",
    "    '--train-steps',\n",
    "    type=int,\n",
    "    help='Maximum number of training steps to perform.')\n",
    "parser.add_argument(\n",
    "    '--eval-steps',\n",
    "    help=\"\"\"Number of steps to run evalution for at each checkpoint.\n",
    "    If unspecified, will run for 1 full epoch over training data\"\"\",\n",
    "    default=None,\n",
    "    type=int)\n",
    "parser.add_argument(\n",
    "    '--train-batch-size',\n",
    "    type=int,\n",
    "    default=40,\n",
    "    help='Batch size for training steps')\n",
    "parser.add_argument(\n",
    "    '--eval-batch-size',\n",
    "    type=int,\n",
    "    default=40,\n",
    "    help='Batch size for evaluation steps')\n",
    "parser.add_argument(\n",
    "    '--learning-rate',\n",
    "    type=float,\n",
    "    default=0.003,\n",
    "    help='Learning rate for SGD')\n",
    "# type=int is required here: the value is compared against an integer\n",
    "# counter in EvaluationRunHook, and a string from the command line would\n",
    "# make that comparison fail on Python 3.\n",
    "parser.add_argument(\n",
    "    '--eval-frequency',\n",
    "    type=int,\n",
    "    default=50,\n",
    "    help='Perform one evaluation per n steps')\n",
    "parser.add_argument(\n",
    "    '--first-layer-size',\n",
    "    type=int,\n",
    "    default=256,\n",
    "    help='Number of nodes in the first layer of DNN')\n",
    "parser.add_argument(\n",
    "    '--num-layers',\n",
    "    type=int,\n",
    "    default=2,\n",
    "    help='Number of layers in DNN')\n",
    "parser.add_argument(\n",
    "    '--scale-factor',\n",
    "    type=float,\n",
    "    default=0.25,\n",
    "    help=\"\"\"Rate of decay size of layer for Deep Neural Net.\n",
    "    max(2, int(first_layer_size * scale_factor**i)) \"\"\")\n",
    "parser.add_argument(\n",
    "    '--num-epochs',\n",
    "    type=int,\n",
    "    help='Maximum number of epochs on which to train')\n",
    "parser.add_argument(\n",
    "    '--export-format',\n",
    "    help='The input format of the exported SavedModel binary',\n",
    "    choices=['JSON', 'CSV', 'EXAMPLE'],\n",
    "    default='JSON')\n",
    "parser.add_argument(\n",
    "    '--verbosity',\n",
    "    choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],\n",
    "    default='INFO',\n",
    "    help='Set logging verbosity')\n",
    "\n",
    "args, _ = parser.parse_known_args()\n",
    "\n",
    "# Set python level verbosity\n",
    "tf.logging.set_verbosity(args.verbosity)\n",
    "# Set C++ Graph Execution level verbosity\n",
    "# Floor division keeps the value an integer string (e.g. '2') on Python 3,\n",
    "# where `/` would produce a float string such as '2.0'.\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(\n",
    "    tf.logging.__dict__[args.verbosity] // 10)\n",
    "\n",
    "# Run the training job.\n",
    "train_and_evaluate(args)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
